#define ARMA_DONT_PRINT_ERRORS

#include <cmath>
#include <memory>
#include <vector>

#include <RcppArmadillo.h>
#include <gsl/gsl_rng.h>
#include <gsl/gsl_randist.h>
#include <gsl/gsl_cdf.h>
// [[Rcpp::depends(RcppArmadillo)]]
// [[Rcpp::depends(RcppGSL)]]

#include <Ziggurat.h>
// [[Rcpp::depends(RcppZiggurat)]]

// Global Ziggurat standard-normal generator (RcppZiggurat). truncn() and
// get1TN() both draw from this single shared object; presumably it is not
// safe to draw from it concurrently on several threads -- confirm before
// using those helpers inside parallel code.
static Ziggurat::Ziggurat::Ziggurat zigg;


//[[Rcpp:plugins(cpp11)]]
#include <omp.h>
// [[Rcpp::plugins(openmp)]]

// Natural log of 2*pi: each dimension contributes -0.5 * log2pi to the
// Gaussian log-density computed by dmvnrm_arma().
const double log2pi = std::log(2.0 * M_PI);



// Cholesky factorisation, computed row by row (Cholesky-Banachiewicz).
//
// Builds the lower-triangular L with L * L.t() == S from the lower triangle
// of S (the strict upper triangle of S is never read) and returns its
// transpose, i.e. the upper-triangular R with R.t() * R == S -- the same
// orientation as arma::chol(S).
//
// No positive-definiteness check is made: a non-positive pivot produces NaN
// (std::sqrt of a negative value) or Inf (division by a zero pivot), which
// then propagates into later entries.
arma::mat mychol(const arma::mat& S) {
  const arma::uword n = S.n_rows;

  // Start from the identity; every lower-triangular entry is overwritten and
  // the strict upper triangle stays zero.
  arma::mat result(n, n, arma::fill::eye);

  for (arma::uword i = 0; i < n; ++i) {
    // Off-diagonal entries L(i, k) for k < i.
    for (arma::uword k = 0; k < i; ++k) {
      double value = S(i, k);
      for (arma::uword j = 0; j < k; ++j) {
        value -= result(i, j) * result(k, j);
      }
      result(i, k) = value / result(k, k);
    }

    // Diagonal entry L(i, i).
    double value = S(i, i);
    for (arma::uword j = 0; j < i; ++j) {
      value -= result(i, j) * result(i, j);
    }
    result(i, i) = std::sqrt(value);
  }

  return result.t();
}



// Draw from N(mu, sigma^2) truncated to [bound, Inf) when `lb` is true, or to
// (-Inf, bound] when `lb` is false, using Geweke's (1991) mixed scheme.
//
// Randomness comes from the global Ziggurat object `zigg` and from R's RNG
// (Rf_runif). Both are shared state, so this function must not be called
// concurrently from several threads (e.g. inside an OpenMP region).
double truncn(double bound, bool lb, double mu, double sigma){
  // 1. Standardised cut-off. The draw is always made as z >= c; upper
  //    truncation is handled by mirroring in step 3.
  const double c = lb ? (bound - mu) / sigma : -(bound - mu) / sigma;

  double z;
  if (c < 0.45) {
    // 2a. Cut-off not far into the tail: plain normal rejection sampling.
    do {
      z = zigg.norm();
    } while (z < c);
  } else {
    // 2b. Tail cut-off: exponential proposal with rate c, accepted with
    //     probability exp(-z^2 / 2), then shifted by c.
    double w;
    do {
      z = -std::log(1 - ::Rf_runif(0.0, 1.0)) / c;
      w = ::Rf_runif(0.0, 1.0);
    } while (w > std::exp(-0.5 * z * z));
    z += c;
  }

  // 3. Undo the standardisation (and the mirroring for upper truncation).
  return lb ? mu + sigma * z : mu - sigma * z;
}



 // Draw from N(mu, sd^2) restricted to the closed interval [low, high] by
// plain rejection: candidates come from the untruncated normal (global
// Ziggurat `zigg`) until one falls inside the interval. There is no
// iteration cap, so the expected cost grows without bound as the interval's
// probability mass shrinks (e.g. mu far from a one-sided bound).
double get1TN(double mu,
              double sd,
              double low,
              double high) {
  double candidate;
  do {
    candidate = mu + zigg.norm() * sd;
  } while (!(candidate >= low && candidate <= high));
  return candidate;
}


// Squared Euclidean norm v'v of a vector stored as an Armadillo matrix.
//
// For a column vector, rhs.t() * rhs is the 1x1 inner product. For a row
// vector that same expression is an n x n outer product whose (0, 0) entry is
// only rhs(0)^2, so row vectors are handled with an explicit dot product.
// For a general matrix (more than one row) this returns
// (rhs.t() * rhs)(0, 0), i.e. the squared norm of the first column.
double vm_mult(const arma::mat& rhs)
{
  if (rhs.n_rows == 1) {
    return arma::dot(rhs, rhs);
  }
  const arma::mat ip = rhs.t() * rhs;
  return ip(0);
}

// Extract the first element (column-major index 0) of `rhs` as a plain
// double, typically to unwrap a 1x1 Armadillo result. Indexing is
// bounds-checked, so an empty matrix throws.
double vm_convert(const arma::mat& rhs)
{
  return rhs(0);
}

// Return element (0, 0) of rhs.t(). Transposition leaves the top-left entry
// in place, so this is simply rhs(0, 0); it is used to unwrap one-element
// results. Indexing is bounds-checked, so an empty matrix throws.
double am_mult2(const arma::mat& rhs)
{
  return rhs(0, 0);
}




// Draw `n` multivariate-normal vectors mu + R' z with z ~ N(0, I).
//
// `sigma` is used directly as the root R (no chol() is taken), so callers
// must pass a factor with R'R equal to the target covariance -- e.g. the
// upper-triangular output of mychol() -- not the covariance itself.
// `mu` is a column vector with ncol(sigma) rows.
// Returns an ncol(sigma) x n matrix holding one draw per column.
// Normals come from arma::randn (with RcppArmadillo this is typically R's
// RNG), so this is not safe to call from multiple threads.
arma::mat mvrnormArma(int n, arma::mat mu, arma::mat sigma) {
  const int dim = sigma.n_cols;
  const arma::mat std_normals = arma::randn(n, dim);  // one N(0, I) row per draw
  const arma::mat means = arma::repmat(mu, 1, n).t(); // mu' repeated n times
  const arma::mat draws = means + std_normals * sigma;
  return draws.t();
}


// Overwrite the row vector x with x * U, where U is the upper triangle of
// `trimat` (entries below the diagonal are never read).
//
// Column j of the product only needs x(0..j), so the columns are processed
// from right to left: when x(j) is overwritten, every x(i) with i < j still
// holds its original value.
void inplace_tri_mat_mult(arma::rowvec &x, arma::mat const &trimat){
  const arma::uword ncol = trimat.n_cols;

  for (arma::uword col = ncol; col > 0; --col) {
    const arma::uword j = col - 1;
    double acc = 0.0;
    for (arma::uword i = 0; i <= j; ++i) {
      acc += trimat.at(i, j) * x[i];
    }
    x[j] = acc;
  }
}

// Multivariate normal (log-)density evaluated at every row of `x`.
//
// IMPORTANT: `sigma` is NOT the covariance. No chol() is applied here, so the
// function expects the upper-triangular Cholesky factor R of the covariance
// (R'R = Sigma) and reads only its upper triangle.
//
//   x     : n x d matrix, one evaluation point per row
//   mean  : 1 x d mean vector
//   sigma : d x d upper-triangular Cholesky factor of the covariance
//   logd  : return log-densities when true, densities otherwise
//
// Returns an n-vector with one (log-)density per row of `x`.
arma::vec dmvnrm_arma(const arma::mat& x,
                      const arma::rowvec& mean,
                      const arma::mat& sigma,
                      bool const logd = false) {
  using arma::uword;
  uword const n = x.n_rows,
              xdim = x.n_cols;
  arma::vec out(n);

  // inv(R) is upper triangular with diagonal 1/diag(R), so
  // sum(log(diag(inv(R)))) = -0.5 * log|Sigma|.
  arma::mat const rooti = arma::inv(arma::trimatu(sigma));
  double const rootisum = arma::sum(arma::log(rooti.diag())),
               constants = -static_cast<double>(xdim) / 2.0 * log2pi,
               other_terms = rootisum + constants;

  arma::rowvec z;
  for (uword i = 0; i < n; i++) {
    // z = (x_i - mean) * inv(R), so dot(z, z) is the Mahalanobis distance.
    z = x.row(i) - mean;
    inplace_tri_mat_mult(z, rooti);
    out(i) = other_terms - 0.5 * arma::dot(z, z);
  }

  if (logd)
    return out;
  return arma::exp(out);
}







// [[Rcpp::export()]]
Rcpp::List mcmc_probit (arma::mat Y,
arma::mat I,
arma::vec tablecountry,
arma::mat X,
arma::colvec Cinit_conv,
arma::colvec Cinit_unconv,
arma::mat alphastart,
arma::mat alphastart_conv,
arma::mat alphastart_unconv,
arma::mat beta_start,
arma::mat Cov_Beta,
arma::rowvec densitybeta,
int mcmc = 100,
int burn=10,
int thin=10,
int chains=2
 ) {

 // Dimensions
int N = Y.n_rows ;
int nitems = Y.n_cols;
int ncountry=I.n_cols;
int nx=X.n_cols;

 // Current Containers
int N1_conv=round(N/2);
int N1_unconv=round(N/2);
arma::mat P_conv(2,1);
P_conv.fill(0.5) ;
arma::mat P_unconv(2,1);
P_unconv.fill(0.5) ;
arma::mat ystar(N, nitems) ;
ystar.fill(0.0) ;
 arma::mat mu(N, nitems) ;
 arma::colvec C_conv = Cinit_conv ;
 arma::colvec C_unconv = Cinit_unconv ;
 arma::mat alpha = alphastart ;
 arma::mat alpha_conv = alphastart_conv ;
 arma::mat alpha_unconv = alphastart_unconv ;
arma::mat  probC_conv(N, 1) ;
 arma::mat  probC_unconv(N, 1) ;
 arma::mat  pnum(N, 2) ;
 arma::mat  pnum_unconv(N, 2) ;
arma::mat  p(N, 1) ;
arma::mat  p_unconv(N, 1) ;
 arma::vec sigma2_alphaintercept(nitems);
 sigma2_alphaintercept.fill(1.0) ;

 arma::vec sigma2_alphaconv(nitems);
 sigma2_alphaconv.fill(1.0) ;

 arma::vec sigma2_alphaunconv(nitems);
 sigma2_alphaunconv.fill(1.0) ;

arma::mat y(N,1);
y.fill(0.0);

 arma::mat etaold_beta(N, 4) ;
 etaold_beta.fill(0.0) ;
 arma::mat etanew_beta(N, 4) ;
 etanew_beta.fill(0.0) ;

arma::mat beta2=beta_start;
arma::mat beta3=beta_start;
arma::mat beta4=beta_start;

 arma::mat beta2new(nx,1);
  beta2new.fill(0.0);
  arma::mat beta3new(nx,1);
  beta3new.fill(0.0);
  arma::mat beta4new(nx,1);
  beta4new.fill(0.0);

arma::mat lold_beta(N,1);
arma::mat lnew_beta(N,1);
 lold_beta.fill(0.0);
 lnew_beta.fill(0.0);



arma::mat sigmabetadensity(nx,nx,arma::fill::eye);




 int lastit= (mcmc-burn)/thin	;

// Trace Containers
  arma::cube tracealpha(nitems*ncountry, lastit,chains) ;
  arma::cube tracealphaconv(nitems*ncountry, lastit,chains) ;
  arma::cube tracealphaunconv(nitems*ncountry, lastit,chains) ;
  arma::cube tracesigma2alpha(nitems,lastit,chains);
  arma::cube tracesigma2alphaconv(nitems,lastit,chains);
  arma::cube tracesigma2alphaunconv(nitems,lastit,chains);
  arma::cube tracePunconv(2, lastit,chains) ;
  arma::cube tracePconv(2, lastit,chains) ;
  arma::cube tracebeta2(nx, lastit,chains) ;
  arma::cube tracebeta3(nx, lastit,chains) ;
  arma::cube tracebeta4(nx, lastit,chains) ;

omp_set_num_threads(chains);

#pragma omp parallel for 
for (int chain = 0; chain < chains; ++chain) {

gsl_rng *s = gsl_rng_alloc(gsl_rng_mt19937);

 // START SAMPLING
 for (int iter = 0 ; iter < mcmc ; iter++) {
// Update C_conv
P_conv(0)= gsl_ran_beta(s, N1_conv+4, N-N1_conv+4);
P_conv(1)=1-P_conv(0);

for (int k = 0 ; k < 2 ; k++) {
arma::mat lindividualitem(N,nitems);
for (int l=0; l<nitems; l++) {
arma::mat mutmp=I*alpha.row(l).t()+(I*alpha_conv.row(l).t())*k+(I*alpha_unconv.row(l).t())%(C_unconv-1);
for (int n = 0 ; n < N ; n++) {
lindividualitem(n,l)=pow(gsl_cdf_ugaussian_P(mutmp(n)), Y(n,l))* pow(gsl_cdf_ugaussian_P(-mutmp(n)), 1-Y(n,l));
}
}
pnum.col(k)=prod(lindividualitem,1)*P_conv(k);
}

p=sum(pnum, 1);
probC_conv=pnum.col(1)/p;

for (int n=0; n<N; n++){
C_conv(n)= gsl_ran_binomial(s,probC_conv(n), 1)+1;
}

N1_conv=N-vm_convert(sum(C_conv-1,0));




// Update C_unconv
P_unconv(0)= gsl_ran_beta(s, N1_unconv+4, N-N1_unconv+4);
P_unconv(1)=1-P_unconv(0);

for (int k = 0 ; k < 2 ; k++) {
arma::mat lindividualitem_unconv(N,nitems);

for (int l=0; l<nitems; l++) {
arma::mat mutmp_unconv=I*alpha.row(l).t()+(I*alpha_conv.row(l).t())%(C_conv-1)+(I*alpha_unconv.row(l).t())*k;
for (int n = 0 ; n < N ; n++) {
lindividualitem_unconv(n,l)=pow(gsl_cdf_ugaussian_P(mutmp_unconv(n)), Y(n,l))* pow(gsl_cdf_ugaussian_P(-mutmp_unconv(n)), 1-Y(n,l));
}
}
pnum_unconv.col(k)=prod(lindividualitem_unconv,1)*P_unconv(k);
}
p_unconv=sum(pnum_unconv, 1);
probC_unconv=pnum_unconv.col(1)/p_unconv;

for (int n=0; n<N; n++){
C_unconv(n)=gsl_ran_binomial(s,probC_unconv(n), 1)+1;
}

N1_unconv=N-vm_convert(sum(C_unconv-1,0));


// Update item-specific parameters
for (int l=0; l<nitems; l++) {

mu.col(l) = I*alpha.row(l).t()+(I*alpha_conv.row(l).t())%(C_conv-1)+(I*alpha_unconv.row(l).t())%(C_unconv-1);

for (int n = 0 ; n < N ; n++) {
 if (Y(n, l) == 1) {
 ystar(n, l) = get1TN(mu(n, l),
 1,
 0,
 INFINITY
 ) ;
 }
 if (Y(n, l) == 0) {
 ystar(n, l) = get1TN(mu(n, l),
 1,
 -INFINITY,
 0
 ) ;
 }
 }

arma::mat llreffectalpha=I.t()*(ystar.col(l)-(I*alpha_conv.row(l).t())%(C_conv-1)-(I*alpha_unconv.row(l).t())%(C_unconv-1));


for (int c=0; c<ncountry; c++) {
double valpha=1/(tablecountry(c)+(1/sigma2_alphaintercept(l)));
double balpha=valpha * llreffectalpha(c);
alpha(l,c)= balpha+gsl_ran_gaussian_ziggurat(s,sqrt(valpha));


}


sigma2_alphaintercept(l)=1.0/gsl_ran_gamma(s, 1+ncountry/2, 1.0/(1.0+vm_mult(alpha.row(l))/2));


}


for (int l=0; l<(nitems-1); l++) {
for (int c=0; c<ncountry; c++) {
double valpha_conv=1/(vm_mult(I.col(c)%(C_conv-1))+(1/sigma2_alphaconv(l)));
double balpha_conv=valpha_conv * vm_convert((I.col(c)%(C_conv-1)).t()*(I.col(c)%(ystar.col(l)-I*alpha.row(l).t()-(I*alpha_unconv.row(l).t())%(C_unconv-1))));
alpha_conv(l,c)=truncn(0.0,TRUE, balpha_conv,  sqrt(valpha_conv));
}

sigma2_alphaconv(l)=1.0/gsl_ran_gamma(s, 1+ncountry/2, 1.0/(1.0+vm_mult(alpha_conv.row(l))/2));

}


for (int l=1; l<nitems; l++) {
for (int c=0; c<ncountry; c++){
double valpha_unconv=1/(vm_mult(I.col(c)%(C_unconv-1))+(1/sigma2_alphaunconv(l)));
double balpha_unconv=valpha_unconv * vm_convert((I.col(c)%(C_unconv-1)).t()*(I.col(c)%(ystar.col(l)-I*alpha.row(l).t()-(I*alpha_conv.row(l).t())%(C_conv-1))));
alpha_unconv(l,c)=truncn(0.0,TRUE, balpha_unconv,  sqrt(valpha_unconv));
}

sigma2_alphaunconv(l)=1.0/gsl_ran_gamma(s, 1+ncountry/2, 1.0/(1.0+vm_mult(alpha_unconv.row(l))/2));

}



for (int n = 0 ; n < N ; n++) {
 if (C_conv(n) == 2 && C_unconv(n) == 2) {
y(n)=4;
} else if (C_conv(n)==2 && C_unconv(n) == 1) {
y(n)=3;
} else if (C_conv(n)==1 && C_unconv(n) == 2) {
y(n)=2;
} else if (C_conv(n)==1 && C_unconv(n) == 1) {
y(n)=1;
}
}


etaold_beta.col(0)=1/(1+exp(X*beta2)+exp(X*beta3)+exp(X*beta4));
etaold_beta.col(1)=exp(X*beta2)%etaold_beta.col(0);
etaold_beta.col(2)=exp(X*beta3)%etaold_beta.col(0);
etaold_beta.col(3)=exp(X*beta4)%etaold_beta.col(0);


beta2new=mvrnormArma(1,  beta2, mychol(Cov_Beta));
beta3new=mvrnormArma(1,  beta3, mychol(Cov_Beta));
beta4new=mvrnormArma(1,  beta4, mychol(Cov_Beta));

etanew_beta.col(0)=1/(1+exp(X*beta2new)+exp(X*beta3new)+exp(X*beta4new));
etanew_beta.col(1)=exp(X*beta2new)%etanew_beta.col(0);
etanew_beta.col(2)=exp(X*beta3new)%etanew_beta.col(0);
etanew_beta.col(3)=exp(X*beta4new)%etanew_beta.col(0);

for (int n = 0 ; n < N ; n++) {
 if (y(n) == 1) {
 lold_beta(n)=log(etaold_beta(n,0));
 lnew_beta(n)=log(etanew_beta(n,0));
} else if (y(n)==2) {
 lold_beta(n)=log(etaold_beta(n,1));
 lnew_beta(n)=log(etanew_beta(n,1));
} else if (y(n)==3) {
 lold_beta(n)=log(etaold_beta(n,2));
 lnew_beta(n)=log(etanew_beta(n,2));
} else if (y(n)==4) {
 lold_beta(n)=log(etaold_beta(n,3));
 lnew_beta(n)=log(etanew_beta(n,3));
}
}


double accept_beta= vm_convert(sum(lnew_beta)-sum(lold_beta))+
			am_mult2(dmvnrm_arma(beta2new.t(),densitybeta, sigmabetadensity, true))+
			  am_mult2(dmvnrm_arma(beta3new.t(),densitybeta, sigmabetadensity, true))+
			  am_mult2(dmvnrm_arma(beta4new.t(),densitybeta, sigmabetadensity, true))-
		        am_mult2(dmvnrm_arma(beta2.t(),densitybeta, sigmabetadensity, true))-
		        am_mult2(dmvnrm_arma(beta3.t(),densitybeta, sigmabetadensity, true))-
		        am_mult2(dmvnrm_arma(beta4.t(),densitybeta, sigmabetadensity, true));




if(log(gsl_rng_uniform(s))<accept_beta) {
for (int l=0; l<nx; ++l) {
beta2(l)=beta2new(l);
beta3(l)=beta3new(l);
beta4(l)=beta4new(l);
}
}



 // TRACE
if ((iter+1)> burn & (iter+1)%thin==0) {
int j=((iter)-burn)/thin;
tracealpha.subcube(0, j, chain, nitems*ncountry-1, j, chain)=vectorise(alpha);
tracealphaconv.subcube(0, j, chain, nitems*ncountry-1, j, chain)=vectorise(alpha_conv);
tracealphaunconv.subcube(0, j, chain, nitems*ncountry-1, j, chain)=vectorise(alpha_unconv);
tracesigma2alpha.subcube(0,j,chain,nitems-1,j,chain)=sigma2_alphaintercept;
tracesigma2alphaconv.subcube(0,j,chain,nitems-1,j,chain)=sigma2_alphaconv;
tracesigma2alphaunconv.subcube(0,j,chain,nitems-1,j,chain)=sigma2_alphaunconv;
tracePconv.subcube(0, j, chain, 1, j, chain)=P_conv;
tracePunconv.subcube(0, j, chain, 1, j, chain)=P_unconv;
tracebeta2.subcube(0, j, chain, (nx-1), j, chain)=beta2;
tracebeta3.subcube(0, j, chain, (nx-1), j, chain)=beta3;
tracebeta4.subcube(0, j, chain, (nx-1), j, chain)=beta4;

}



} 
gsl_rng_free(s);
}

// Returns
Rcpp::List ret;

 ret["alpha"] = tracealpha;
 ret["alpha_conv"] = tracealphaconv;
 ret["alpha_unconv"] = tracealphaunconv;
 ret["beta2"] = tracebeta2;
 ret["beta3"] = tracebeta3;
 ret["beta4"] = tracebeta4;
ret["sigma2_alphaintercept"] = tracesigma2alpha;
ret["sigma2_alphaconv"] = tracesigma2alphaconv;
ret["sigma2_alphaunconv"] = tracesigma2alphaunconv;
ret["P_conv"] = tracePconv;
ret["P_unconv"] = tracePunconv;

 return(ret) ;
 

}

